{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "04160a3f",
   "metadata": {},
   "source": [
    "# Tutorial16: DIFFPOOL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a8d78493",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4643b351",
   "metadata": {},
   "source": [
    "The cells below compute the node feature matrix and the adjacency matrix for the first hierarchical step.\n",
    "\n",
    "Initial graph: \n",
    "```\n",
    "x_0   = 50 x 32\n",
    "adj_0  = 50 x 50\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "82ee82c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node feature matrix: 50 nodes, 32 features each\n",
    "x_0 = torch.rand(50, 32)\n",
    "# Random binary adjacency matrix. The diagonal is cleared first so that adding\n",
    "# the identity gives every node exactly one self-loop (entries stay in {0, 1}).\n",
    "adj_0 = torch.rand(50, 50).round()\n",
    "adj_0.fill_diagonal_(0)\n",
    "identity = torch.eye(50)\n",
    "adj_0 = adj_0 + identity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28024fda",
   "metadata": {},
   "source": [
    "Set the number of clusters we want to obtain at step 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bd302dd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster counts: n_clusters_0 matches the 50 nodes of the input graph,\n",
    "# n_clusters_1 is the number of clusters requested for step 1\n",
    "n_clusters_0, n_clusters_1 = 50, 5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a49c2a5",
   "metadata": {},
   "source": [
    "Initialize the weights of GNN_emb and GNN_pool. For simplicity, each GNN uses just one convolutional layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3ff27350",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One weight matrix per GNN: 32 input features -> 16 embedding dims / n_clusters_1 cluster scores\n",
    "w_gnn_emb = torch.rand((32, 16))\n",
    "w_gnn_pool = torch.rand((32, n_clusters_1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "106f6517",
   "metadata": {},
   "source": [
    "<img src=\"img1.png\" width=300px>\n",
    "<img src=\"img2.png\" width=400px>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d5bf26cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node embeddings: one propagation step followed by ReLU\n",
    "z_0 = (adj_0 @ x_0 @ w_gnn_emb).relu()\n",
    "# Cluster assignment: same propagation with the pooling weights, then a row-wise softmax\n",
    "s_0 = (adj_0 @ x_0 @ w_gnn_pool).relu().softmax(dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ff679f1",
   "metadata": {},
   "source": [
    "<img src=\"img3.png\" width=200px>\n",
    "<img src=\"img4.png\" width=200px>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7fdb0427",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coarsened features: S^T Z\n",
    "x_1 = torch.matmul(s_0.t(), z_0)\n",
    "# Coarsened adjacency: S^T A S\n",
    "adj_1 = torch.matmul(torch.matmul(s_0.t(), adj_0), s_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "610fa823",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 16])\n",
      "torch.Size([5, 5])\n"
     ]
    }
   ],
   "source": [
    "# Shapes of the pooled feature and adjacency matrices, one per line\n",
    "print(x_1.shape, adj_1.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6b7fbcae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "from math import ceil\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.data import DenseDataLoader\n",
    "from torch_geometric.nn import DenseGCNConv as GCNConv, dense_diff_pool\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c47ebf13",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_nodes = 150\n",
    "\n",
    "\n",
    "class MyFilter:\n",
    "    \"\"\"Predicate that accepts a graph only if it has at most `max_nodes` nodes.\"\"\"\n",
    "\n",
    "    def __call__(self, data):\n",
    "        # `max_nodes` is the notebook-level constant, looked up at call time\n",
    "        return max_nodes >= data.num_nodes\n",
    "\n",
    "\n",
    "# Load PROTEINS with MyFilter as pre-filter and ToDense(max_nodes) as transform,\n",
    "# then shuffle before splitting\n",
    "dataset = TUDataset(\n",
    "    'data',\n",
    "    name='PROTEINS',\n",
    "    transform=T.ToDense(max_nodes),\n",
    "    pre_filter=MyFilter(),\n",
    ").shuffle()\n",
    "\n",
    "# First tenth (rounded up) for testing, second tenth for validation, the rest for training\n",
    "n = (len(dataset) + 9) // 10\n",
    "test_dataset, val_dataset, train_dataset = dataset[:n], dataset[n:2 * n], dataset[2 * n:]\n",
    "\n",
    "test_loader = DenseDataLoader(test_dataset, batch_size=32)\n",
    "val_loader = DenseDataLoader(val_dataset, batch_size=32)\n",
    "train_loader = DenseDataLoader(train_dataset, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d2efbcfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNN(torch.nn.Module):\n",
    "    \"\"\"Stack of three dense GCN layers, each followed by ReLU and batch norm.\n",
    "\n",
    "    Args:\n",
    "        in_channels: Feature size of the input nodes.\n",
    "        hidden_channels: Output size of the first two convolutions.\n",
    "        out_channels: Output size of the last convolution.\n",
    "        normalize: Passed as the third positional argument of every\n",
    "            convolution (for DenseGCNConv that slot appears to be `improved`\n",
    "            -- TODO confirm this is the intended flag).\n",
    "        lin: Accepted for interface compatibility; not used by this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels,\n",
    "                 normalize=False, lin=True):\n",
    "        super(GNN, self).__init__()\n",
    "\n",
    "        self.convs = torch.nn.ModuleList()\n",
    "        self.bns = torch.nn.ModuleList()\n",
    "\n",
    "        layer_sizes = [\n",
    "            (in_channels, hidden_channels),\n",
    "            (hidden_channels, hidden_channels),\n",
    "            (hidden_channels, out_channels),\n",
    "        ]\n",
    "        for fan_in, fan_out in layer_sizes:\n",
    "            self.convs.append(GCNConv(fan_in, fan_out, normalize))\n",
    "            self.bns.append(torch.nn.BatchNorm1d(fan_out))\n",
    "\n",
    "    def forward(self, x, adj, mask=None):\n",
    "        \"\"\"Apply conv -> ReLU -> batch norm for every layer.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch_size, num_nodes, out_channels).\n",
    "        \"\"\"\n",
    "        for conv, bn in zip(self.convs, self.bns):\n",
    "            x = F.relu(conv(x, adj, mask))\n",
    "            # BatchNorm1d treats dim 1 as the channel axis, so a\n",
    "            # (batch, nodes, channels) tensor must be flattened to\n",
    "            # (batch * nodes, channels) first; otherwise the node axis is\n",
    "            # mistaken for channels and the running statistics do not match.\n",
    "            batch_size, num_nodes, num_channels = x.size()\n",
    "            x = bn(x.reshape(-1, num_channels))\n",
    "            x = x.reshape(batch_size, num_nodes, num_channels)\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "class DiffPool(torch.nn.Module):\n",
    "    \"\"\"Graph classifier built from two DiffPool coarsening stages and a final GNN.\n",
    "\n",
    "    Layer sizes are read from the notebook-level `dataset` and `max_nodes`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(DiffPool, self).__init__()\n",
    "\n",
    "        # Each pooling stage keeps a quarter of the nodes, rounded up\n",
    "        clusters_1 = ceil(0.25 * max_nodes)\n",
    "        clusters_2 = ceil(0.25 * clusters_1)\n",
    "\n",
    "        self.gnn1_pool = GNN(dataset.num_features, 64, clusters_1)\n",
    "        self.gnn1_embed = GNN(dataset.num_features, 64, 64)\n",
    "\n",
    "        self.gnn2_pool = GNN(64, 64, clusters_2)\n",
    "        self.gnn2_embed = GNN(64, 64, 64, lin=False)\n",
    "\n",
    "        self.gnn3_embed = GNN(64, 64, 64, lin=False)\n",
    "\n",
    "        self.lin1 = torch.nn.Linear(64, 64)\n",
    "        self.lin2 = torch.nn.Linear(64, dataset.num_classes)\n",
    "\n",
    "    def forward(self, x, adj, mask=None):\n",
    "        \"\"\"Return (log-probabilities, l1 + l2, e1 + e2).\n",
    "\n",
    "        l* and e* are the two auxiliary values returned by each\n",
    "        `dense_diff_pool` call.\n",
    "        \"\"\"\n",
    "        # Stage 1: assignment scores and embeddings on the input graph\n",
    "        assign_1 = self.gnn1_pool(x, adj, mask)\n",
    "        embed_1 = self.gnn1_embed(x, adj, mask)\n",
    "\n",
    "        # Counterpart of the manual example above:\n",
    "        #   x_1 = s_0.t() @ z_0\n",
    "        #   adj_1 = s_0.t() @ adj_0 @ s_0\n",
    "        x, adj, l1, e1 = dense_diff_pool(embed_1, adj, assign_1, mask)\n",
    "\n",
    "        # Stage 2: same procedure on the coarsened graph (no mask is passed)\n",
    "        assign_2 = self.gnn2_pool(x, adj)\n",
    "        embed_2 = self.gnn2_embed(x, adj)\n",
    "        x, adj, l2, e2 = dense_diff_pool(embed_2, adj, assign_2)\n",
    "\n",
    "        x = self.gnn3_embed(x, adj)\n",
    "\n",
    "        # Average over the node axis, then classify\n",
    "        graph_repr = x.mean(dim=1)\n",
    "        logits = self.lin2(F.relu(self.lin1(graph_repr)))\n",
    "        return F.log_softmax(logits, dim=-1), l1 + l2, e1 + e2\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "f44e6f5b",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "running_mean should contain 150 elements not 64",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-26-6a7a0707c268>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0mbest_val_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m151\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m     \u001b[0mtrain_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     36\u001b[0m     \u001b[0mval_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     37\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mval_acc\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mbest_val_acc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-26-6a7a0707c268>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(epoch)\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m         \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.miniconda3/envs/pyg/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1052\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1053\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1054\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1055\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1056\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-25-e55db94e19f9>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, adj, mask)\u001b[0m\n\u001b[1;32m     53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m         \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgnn1_pool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     56\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgnn1_embed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.miniconda3/envs/pyg/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1052\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1053\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1054\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1055\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1056\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-25-e55db94e19f9>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, adj, mask)\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m             \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.miniconda3/envs/pyg/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1052\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1053\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1054\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1055\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1056\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.miniconda3/envs/pyg/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    165\u001b[0m         \u001b[0mused\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mnormalization\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval\u001b[0m \u001b[0mmode\u001b[0m \u001b[0mwhen\u001b[0m \u001b[0mbuffers\u001b[0m \u001b[0mare\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    166\u001b[0m         \"\"\"\n\u001b[0;32m--> 167\u001b[0;31m         return F.batch_norm(\n\u001b[0m\u001b[1;32m    168\u001b[0m             \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    169\u001b[0m             \u001b[0;31m# If buffers are not to be tracked, ensure that they won't be updated\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.miniconda3/envs/pyg/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbatch_norm\u001b[0;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[1;32m   2279\u001b[0m         \u001b[0m_verify_batch_size\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2281\u001b[0;31m     return torch.batch_norm(\n\u001b[0m\u001b[1;32m   2282\u001b[0m         \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_mean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_var\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2283\u001b[0m     )\n",
      "\u001b[0;31mRuntimeError\u001b[0m: running_mean should contain 150 elements not 64"
     ]
    }
   ],
   "source": [
    "# Use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "model = DiffPool()\n",
    "model = model.to(device)\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)\n",
    "\n",
    "\n",
    "def train(epoch):\n",
    "    \"\"\"Run one pass over `train_loader` and return the mean loss per graph.\n",
    "\n",
    "    `epoch` is accepted but not used. Only the classification loss is\n",
    "    optimized; the two auxiliary outputs of the model are discarded.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "\n",
    "    for batch in train_loader:\n",
    "        batch = batch.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        log_probs = model(batch.x, batch.adj, batch.mask)[0]\n",
    "        loss = F.nll_loss(log_probs, batch.y.view(-1))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Weight each batch loss by its number of graphs\n",
    "        total_loss += batch.y.size(0) * loss.item()\n",
    "    return total_loss / len(train_dataset)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(loader):\n",
    "    \"\"\"Return the fraction of graphs in `loader.dataset` that are classified correctly.\"\"\"\n",
    "    model.eval()\n",
    "    hits = 0\n",
    "\n",
    "    for batch in loader:\n",
    "        batch = batch.to(device)\n",
    "        log_probs = model(batch.x, batch.adj, batch.mask)[0]\n",
    "        # Predicted class = index of the largest log-probability\n",
    "        _, pred = log_probs.max(dim=1)\n",
    "        hits += (pred == batch.y.view(-1)).sum().item()\n",
    "    return hits / len(loader.dataset)\n",
    "\n",
    "\n",
    "best_val_acc, test_acc = 0, 0\n",
    "for epoch in range(1, 151):\n",
    "    train_loss = train(epoch)\n",
    "    val_acc = test(val_loader)\n",
    "    # Refresh the reported test accuracy only when validation accuracy improves\n",
    "    if best_val_acc < val_acc:\n",
    "        best_val_acc, test_acc = val_acc, test(test_loader)\n",
    "    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '\n",
    "          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a485b54",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eebb307b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
